{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LLddJfAgbtmi"
      },
      "source": [
        "##### Copyright 2019 The TF-Agents Authors."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HqslkUeyEJFg"
      },
      "source": [
        "# Tutorial on Multi Armed Bandits in TF-Agents"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MimUC9NrYFaS"
      },
      "source": [
        "### Get Started\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/agents/blob/master/tf_agents/bandits/colabs/bandits_tutorial.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/colabs/bandits_tutorial.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "O7gLdUS6b2EG"
      },
      "source": [
        "### Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "3oCS94Z83Jo2"
      },
      "outputs": [],
      "source": [
        "import abc\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "from tf_agents.agents import tf_agent\n",
        "from tf_agents.drivers import driver\n",
        "from tf_agents.environments import py_environment\n",
        "from tf_agents.environments import tf_environment\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.policies import tf_policy\n",
        "from tf_agents.specs import array_spec\n",
        "from tf_agents.specs import tensor_spec\n",
        "from tf_agents.trajectories import time_step as ts\n",
        "from tf_agents.trajectories import trajectory\n",
        "from tf_agents.trajectories import policy_step\n",
        "\n",
        "# Clear any leftover state from previous colabs run.\n",
        "# (This is not necessary for normal programs.)\n",
        "tf.reset_default_graph()\n",
        "\n",
        "tf.compat.v1.enable_resource_variables()\n",
        "tf.compat.v1.enable_v2_behavior()\n",
        "nest = tf.compat.v2.nest"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CcIob6rYqien"
      },
      "source": [
        "# Introduction\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JdnTJrzaeft3"
      },
      "source": [
        "The Multi-Armed Bandit problem (MAB) is a special case of Reinforcement Learning: an agent collects rewards in an environment by taking some actions after observing some state of the environment. The main difference between general RL and MAB is that in MAB, we assume that the action taken by the agent does not influence the next state of the environment. Therefore, agents do not model state transitions, credit rewards to past actions, or \"plan ahead\" to get to reward-rich states.\n",
        "\n",
        "As in other RL domains, the goal of a MAB *agent* is to find a *policy* that collects as much reward as possible. It would be a mistake, however, to always try to exploit the action that promises the highest reward, because then there is a chance that we miss out on better actions if we do not explore enough. This is the main problem to be solved in (MAB), often called the *exploration-exploitation dilemma*.\n",
        "\n",
        "\n",
        "\n",
        "Bandit environments, policies, and agents for MAB can be found in subdirectories of [tf_agents/bandits](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits)."
      ]
    },
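    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the exploration-exploitation dilemma concrete, here is a minimal NumPy sketch of an epsilon-greedy learner on a three-armed bandit without context. It is only an illustration and not part of TF-Agents: the function name `run_epsilon_greedy`, the arm means, and the value of `epsilon` are assumptions made up for this example, and epsilon-greedy is just one simple strategy among many.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def run_epsilon_greedy(true_means, epsilon, num_steps, seed=0):\n",
        "  # Returns the average reward collected and the final per-arm estimates.\n",
        "  rng = np.random.default_rng(seed)\n",
        "  num_arms = len(true_means)\n",
        "  counts = np.zeros(num_arms)\n",
        "  estimates = np.zeros(num_arms)\n",
        "  total_reward = 0.0\n",
        "  for _ in range(num_steps):\n",
        "    if rng.random() < epsilon:\n",
        "      arm = int(rng.integers(num_arms))  # Explore: pick a random arm.\n",
        "    else:\n",
        "      arm = int(np.argmax(estimates))    # Exploit: pick the best estimate.\n",
        "    reward = rng.normal(true_means[arm], 1.0)\n",
        "    counts[arm] += 1\n",
        "    # Incremental update of the running mean for the chosen arm.\n",
        "    estimates[arm] += (reward - estimates[arm]) / counts[arm]\n",
        "    total_reward += reward\n",
        "  return total_reward / num_steps, estimates\n",
        "\n",
        "avg_reward, estimates = run_epsilon_greedy(\n",
        "    true_means=[0.1, 0.5, 0.9], epsilon=0.1, num_steps=5000)\n",
        "print('average reward: %.3f' % avg_reward)\n",
        "print('estimated means:', np.round(estimates, 2))\n",
        "```\n",
        "\n",
        "With `epsilon=0` the learner never explores and may lock onto whichever arm happened to look best early on; a small positive `epsilon` gives up a little reward in exchange for better estimates of every arm."
      ]
    },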
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "iPzsBCTperx3"
      },
      "source": [
        "# Environments"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1LOXW8i320Cp"
      },
      "source": [
        "In TF-Agents, the environment class serves the role of giving information on the current state (this is called **observation** or **context**), receiving an action as input, performing a state transition, and outputting a reward. This class also takes care of resetting when an episode ends, so that a new episode can start. This is realized by calling a `reset` function when a state is labelled as \"last\" of the episode.\n",
        "\n",
        "For more details, see the [TF-Agents environments tutorial](https://github.com/tensorflow/agents/blob/master/tf_agents/colabs/2_environments_tutorial.ipynb).\n",
        "\n",
        "As mentioned above, MAB differs from general RL in that actions do not influence the next observation. Another difference is that in Bandits, there are no \"episodes\": every time step starts with a new observation, independently of previous time steps.\n",
        "\n",
        "To make sure observations are independent and to abstract away the concept of RL episodes, we introduce subclasses of `PyEnvironment` and `TFEnvironment`: [BanditPyEnvironment](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/bandit_tf_environment.py) and [BanditTFEnvironment](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/bandit_py_environment.py). These classes expose two private member functions that remain to be implemented by the user:\n",
        "\n",
        "```python\n",
        "@abc.abstractmethod\n",
        "def _observe(self):\n",
        "```\n",
        "and\n",
        "```python\n",
        "@abc.abstractmethod\n",
        "def _apply_action(self, action):\n",
        "```\n",
        "The `_observe` function returns an observation. Then, the policy chooses an action based on this observation. The `_apply_action` receives that action as an input, and returns the corresponding reward. These private member functions are called by the functions `reset` and `step`, respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TTaG2ZapQvHX"
      },
      "outputs": [],
      "source": [
        "class BanditPyEnvironment(py_environment.PyEnvironment):\n",
        "\n",
        "  def __init__(self, observation_spec, action_spec):\n",
        "    self._observation_spec = observation_spec\n",
        "    self._action_spec = action_spec\n",
        "    super(BanditPyEnvironment, self).__init__()\n",
        "\n",
        "  # Helper functions.\n",
        "  def action_spec(self):\n",
        "    return self._action_spec\n",
        "\n",
        "  def observation_spec(self):\n",
        "    return self._observation_spec\n",
        "\n",
        "  def _empty_observation(self):\n",
        "    return tf.nest.map_structure(lambda x: np.zeros(x.shape, x.dtype),\n",
        "                                 self.observation_spec())\n",
        "\n",
        "  # These two functions below should not be overridden by subclasses.\n",
        "  def _reset(self):\n",
        "    \"\"\"Returns a time step containing an observation.\"\"\"\n",
        "    return ts.restart(self._observe(), batch_size=self.batch_size)\n",
        "\n",
        "  def _step(self, action):\n",
        "    \"\"\"Returns a time step containing the reward for the action taken.\"\"\"\n",
        "    reward = self._apply_action(action)\n",
        "    return ts.termination(self._observe(), reward)\n",
        "\n",
        "  # These two functions below are to be implemented in subclasses.\n",
        "  @abc.abstractmethod\n",
        "  def _observe(self):\n",
        "    \"\"\"Returns an observation.\"\"\"\n",
        "\n",
        "  @abc.abstractmethod\n",
        "  def _apply_action(self, action):\n",
        "    \"\"\"Applies `action` to the Environment and returns the corresponding reward.\n",
        "    \"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZVtLk28xVo0j"
      },
      "source": [
        "The above interim abstract class implements `PyEnvironment`'s `_reset` and `_step` functions and exposes the abstract functions `_observe` and `_apply_action` to be implemented by subclasses."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xQbI-6PdtSJn"
      },
      "source": [
        "## A Simple Example Environment Class"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8qspwAx0tS6l"
      },
      "source": [
        "The following class gives a very simple environment for which the observation is a random integer between -2 and 2, there are 3 possible actions (0, 1, 2), and the reward is the product of the action and the observation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "YV6DhsSi227-"
      },
      "outputs": [],
      "source": [
        "class SimplePyEnvironment(BanditPyEnvironment):\n",
        "\n",
        "  def __init__(self):\n",
        "    action_spec = array_spec.BoundedArraySpec(\n",
        "        shape=(), dtype=np.int32, minimum=0, maximum=2, name='action')\n",
        "    observation_spec = array_spec.BoundedArraySpec(\n",
        "        shape=(1,), dtype=np.int32, minimum=-2, maximum=2, name='observation')\n",
        "    super(SimplePyEnvironment, self).__init__(observation_spec, action_spec)\n",
        "\n",
        "  def _observe(self):\n",
        "    self._observation = np.random.randint(-2, 3, (1,))\n",
        "    return self._observation\n",
        "\n",
        "  def _apply_action(self, action):\n",
        "    return action * self._observation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ipEQgYDIf55t"
      },
      "source": [
        "Now we can use this environment to get observations, and receive rewards for our actions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Eo_uwSz2gAKX"
      },
      "outputs": [],
      "source": [
        "environment = SimplePyEnvironment()\n",
        "observation = environment.reset().observation\n",
        "print(\"observation: %d\" % observation)\n",
        "\n",
        "action = 2 #@param\n",
        "\n",
        "print(\"action: %d\" % action)\n",
        "reward = environment.step(action).reward\n",
        "print(\"reward: %f\" % reward)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GuVYHI8aDgCx"
      },
      "source": [
        "## TF Environments"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dP46VwLTDnOR"
      },
      "source": [
        "One can define a bandit environment by subclassing `BanditTFEnvironment`, or, similarly to RL environments, one can define a `BanditPyEnvironment` and wrap it with `TFPyEnvironment`. For the sake of simplicity, we go with the latter option in this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "IPPpwSi3EtWz"
      },
      "outputs": [],
      "source": [
        "tf_environment = tf_py_environment.TFPyEnvironment(environment)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-S9fhxF9GUaT"
      },
      "source": [
        "# Policies"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NbTt5jnuGlYj"
      },
      "source": [
        "A *policy* in a bandit problem works the same way as in an RL problem: it provides an action (or a distribution of actions), given an observation as input.\n",
        "\n",
        "For more details, see the [TF-Agents Policy tutorial](https://github.com/tensorflow/agents/blob/master/tf_agents/colabs/3_policies_tutorial.ipynb).\n",
        "\n",
        "As with environments, there are two ways to construct a policy: One can create a `PyPolicy` and wrap it with `TFPyPolicy`, or directly create a `TFPolicy`. Here we elect to go with the direct method.\n",
        "\n",
        "Since this example is quite simple, we can define the optimal policy manually. The action only depends on the sign of the observation, 0 when is negative and 2 when is positive."
      ]
    },
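    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a quick sanity check of that claim, the following NumPy sketch enumerates every observation and action of the simple environment described earlier (observations from -2 to 2, actions 0 to 2, reward equal to their product) and compares the sign rule with the best achievable reward. The variable names are made up for this illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "observations = np.arange(-2, 3)  # All possible observations.\n",
        "actions = np.arange(0, 3)        # All possible actions.\n",
        "\n",
        "# reward_table[i, j] is the reward for observation i and action j.\n",
        "reward_table = np.outer(observations, actions)\n",
        "best_reward = reward_table.max(axis=1)\n",
        "\n",
        "# The hand-written rule: sign(observation) + 1.\n",
        "sign_actions = np.sign(observations) + 1\n",
        "sign_reward = observations * sign_actions\n",
        "\n",
        "for obs, act, rew in zip(observations, sign_actions, sign_reward):\n",
        "  print('observation %2d -> action %d, reward %d' % (obs, act, rew))\n",
        "print('matches the best achievable reward:',\n",
        "      np.array_equal(sign_reward, best_reward))\n",
        "```\n",
        "\n",
        "When the observation is 0, every action yields reward 0, so the rule's choice of action 1 in that case is as good as any other."
      ]
    },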
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VpMZlplNK5ND"
      },
      "outputs": [],
      "source": [
        "class SignPolicy(tf_policy.TFPolicy):\n",
        "  def __init__(self):\n",
        "    observation_spec = tensor_spec.BoundedTensorSpec(\n",
        "        shape=(1,), dtype=tf.int32, minimum=-2, maximum=2)\n",
        "    time_step_spec = ts.time_step_spec(observation_spec)\n",
        "\n",
        "    action_spec = tensor_spec.BoundedTensorSpec(\n",
        "        shape=(), dtype=tf.int32, minimum=0, maximum=2)\n",
        "\n",
        "    super(SignPolicy, self).__init__(time_step_spec=time_step_spec,\n",
        "                                     action_spec=action_spec)\n",
        "  def _distribution(self, time_step):\n",
        "    pass\n",
        "\n",
        "  def _variables(self):\n",
        "    return ()\n",
        "\n",
        "  def _action(self, time_step, policy_state, seed):\n",
        "    observation_sign = tf.cast(tf.sign(time_step.observation[0]), dtype=tf.int32)\n",
        "    action = observation_sign + 1\n",
        "    return policy_step.PolicyStep(action, policy_state)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GAM7hb4LVQ70"
      },
      "source": [
        "Now we can request an observation from the environment, call the policy to choose an action, then the environment will output the reward:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Z0_5vMDCVZWT"
      },
      "outputs": [],
      "source": [
        "sign_policy = SignPolicy()\n",
        "\n",
        "current_time_step = tf_environment.reset()\n",
        "print('Observation:')\n",
        "print (current_time_step.observation)\n",
        "action = sign_policy.action(current_time_step).action\n",
        "print('Action:')\n",
        "print (action)\n",
        "reward = tf_environment.step(action).reward\n",
        "print('Reward:')\n",
        "print(reward)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AExuQ7u0-PF6"
      },
      "source": [
        "The way bandit environments are implemented ensures that every time we take a step, we not only receive the reward for the action we took, but also the next observation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CiB935of-wVv"
      },
      "outputs": [],
      "source": [
        "step = tf_environment.reset()\n",
        "action = 1\n",
        "next_step = tf_environment.step(action)\n",
        "reward = next_step.reward\n",
        "next_observation = next_step.observation\n",
        "print(\"Reward: \")\n",
        "print(reward)\n",
        "print(\"Next observation:\")\n",
        "print(next_observation)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zFnqVHfeANZP"
      },
      "source": [
        "# Agents"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1pDK_faXAPSA"
      },
      "source": [
        "Now that we have bandit environments and bandit policies, it is time to also define bandit agents, that take care of changing the policy based on training samples.\n",
        "\n",
        "The API for bandit agents does not differ from that of RL agents: the agent just needs to implement the `_initialize` and `_train` methods, and define a `policy` and a `collect_policy`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TVCb-vPJOayG"
      },
      "source": [
        "## A More Complicated Environment"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9Ksv7i7zPGSa"
      },
      "source": [
        "Before we write our bandit agent, we need to have an environment that is a bit harder to figure out. To spice up things just a little bit, the next environment will either always give `reward = observation * action` or `reward = -observation * action`. This will be decided when the environment is initialized."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "fte7-Mr8O0QR"
      },
      "outputs": [],
      "source": [
        "class TwoWayPyEnvironment(BanditPyEnvironment):\n",
        "\n",
        "  def __init__(self):\n",
        "    action_spec = array_spec.BoundedArraySpec(\n",
        "        shape=(), dtype=np.int32, minimum=0, maximum=2, name='action')\n",
        "    observation_spec = array_spec.BoundedArraySpec(\n",
        "        shape=(1,), dtype=np.int32, minimum=-2, maximum=2, name='observation')\n",
        "\n",
        "    # Flipping the sign with probability 1/2.\n",
        "    self._reward_sign = 2 * np.random.randint(2) - 1\n",
        "    print(\"reward sign:\")\n",
        "    print(self._reward_sign)\n",
        "\n",
        "    super(TwoWayPyEnvironment, self).__init__(observation_spec, action_spec)\n",
        "\n",
        "  def _observe(self):\n",
        "    self._observation = np.random.randint(-2, 3, (1,))\n",
        "    return self._observation\n",
        "\n",
        "  def _apply_action(self, action):\n",
        "    return self._reward_sign * action * self._observation[0]\n",
        "\n",
        "two_way_tf_environment = tf_py_environment.TFPyEnvironment(TwoWayPyEnvironment())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7Zb4jWpQUA75"
      },
      "source": [
        "## A More Complicated Policy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Dz2rEEA1USJu"
      },
      "source": [
        "A more complicated environment calls for a more complicated policy. We need a policy that detects the behavior of the underlying environment. There are three situations that the policy needs to handle:\n",
        "\n",
        "0.   The agent has not detected know yet which version of the environment is running.\n",
        "1.   The agent detected that the original version of the environment is running.\n",
        "2.   The agent detected that the flipped version of the environment is running.\n",
        "\n",
        "We define a `tf_variable` named `_situation` to store this information encoded as values in `[0, 2]`, then make the policy behave accordingly."
      ]
    },
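    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The three situations above, and the rule for leaving the unknown one, can be sketched in plain Python before turning to the TensorFlow version. This is a simplified illustration: the constant names and the function `infer_situation` are hypothetical, and the arithmetic mirrors the sign-based formula used by the agent defined later in this tutorial.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "UNKNOWN, NORMAL, FLIPPED = 0, 1, 2\n",
        "\n",
        "def infer_situation(situation, observation, action, reward):\n",
        "  # Updates the situation code from one (observation, action, reward) sample.\n",
        "  if situation != UNKNOWN or reward == 0:\n",
        "    # Already decided, or this sample carries no information about the sign.\n",
        "    return situation\n",
        "  # A positive product means the reward agrees in sign with\n",
        "  # observation * action, so the environment is not flipped.\n",
        "  product_sign = int(np.sign(observation * action * reward))\n",
        "  return (3 - product_sign) // 2\n",
        "\n",
        "print(infer_situation(UNKNOWN, observation=2, action=1, reward=2))   # 1\n",
        "print(infer_situation(UNKNOWN, observation=2, action=1, reward=-2))  # 2\n",
        "print(infer_situation(UNKNOWN, observation=0, action=1, reward=0))   # 0\n",
        "```\n",
        "\n",
        "Choosing action 1 while the situation is unknown matters here: with action 0 the reward would always be 0 and the sign could never be inferred."
      ]
    },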
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Srm2jsGHVM8N"
      },
      "outputs": [],
      "source": [
        "class TwoWaySignPolicy(tf_policy.TFPolicy):\n",
        "  def __init__(self, situation):\n",
        "    observation_spec = tensor_spec.BoundedTensorSpec(\n",
        "        shape=(1,), dtype=tf.int32, minimum=-2, maximum=2)\n",
        "    action_spec = tensor_spec.BoundedTensorSpec(\n",
        "        shape=(), dtype=tf.int32, minimum=0, maximum=2)\n",
        "    time_step_spec = ts.time_step_spec(observation_spec)\n",
        "    self._situation = situation\n",
        "    super(TwoWaySignPolicy, self).__init__(time_step_spec=time_step_spec,\n",
        "                                           action_spec=action_spec)\n",
        "  def _distribution(self, time_step):\n",
        "    pass\n",
        "\n",
        "  def _variables(self):\n",
        "    return [self._situation]\n",
        "\n",
        "  def _action(self, time_step, policy_state, seed):\n",
        "    sign = tf.cast(tf.sign(time_step.observation[0, 0]), dtype=tf.int32)\n",
        "    def case_unknown_fn():\n",
        "      # Choose 1 so that we get information on the sign.\n",
        "      return tf.constant(1, shape=(1,))\n",
        "\n",
        "    # Choose 0 or 2, depending on the situation and the sign of the observation.\n",
        "    def case_normal_fn():\n",
        "      return tf.constant(sign + 1, shape=(1,))\n",
        "    def case_flipped_fn():\n",
        "      return tf.constant(1 - sign, shape=(1,))\n",
        "\n",
        "    cases = [(tf.equal(self._situation, 0), case_unknown_fn),\n",
        "             (tf.equal(self._situation, 1), case_normal_fn),\n",
        "             (tf.equal(self._situation, 2), case_flipped_fn)]\n",
        "    action = tf.case(cases, exclusive=True)\n",
        "    return policy_step.PolicyStep(action, policy_state)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "r6PPdRQQbE3Q"
      },
      "source": [
        "## The Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pO8HpL0tUP32"
      },
      "source": [
        "Now it's time to define the agent that detects the sign of the environment and sets the policy appropriately."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "7f-0W0cMbS_z"
      },
      "outputs": [],
      "source": [
        "class SignAgent(tf_agent.TFAgent):\n",
        "  def __init__(self):\n",
        "    self._situation = tf.compat.v2.Variable(0, dtype=tf.int32)\n",
        "    policy = TwoWaySignPolicy(self._situation)\n",
        "    time_step_spec = policy.time_step_spec\n",
        "    action_spec = policy.action_spec\n",
        "    super(SignAgent, self).__init__(time_step_spec=time_step_spec,\n",
        "                                    action_spec=action_spec,\n",
        "                                    policy=policy,\n",
        "                                    collect_policy=policy,\n",
        "                                    train_sequence_length=None)\n",
        "\n",
        "  def _initialize(self):\n",
        "    return tf.compat.v1.variables_initializer(self.variables)\n",
        "\n",
        "  def _train(self, experience, weights=None):\n",
        "    observation = experience.observation\n",
        "    action = experience.action\n",
        "    reward = experience.reward\n",
        "\n",
        "    # We only need to change the value of the situation variable if it is\n",
        "    # unknown (0) right now, and we can infer the situation only if the\n",
        "    # observation is not 0.\n",
        "    needs_action = tf.logical_and(tf.equal(self._situation, 0),\n",
        "                                  tf.not_equal(reward, 0))\n",
        "\n",
        "\n",
        "    def new_situation_fn():\n",
        "      \"\"\"This returns either 1 or 2, depending on the signs.\"\"\"\n",
        "      return (3 - tf.sign(tf.cast(observation[0, 0, 0], dtype=tf.int32) *\n",
        "                          tf.cast(action[0, 0], dtype=tf.int32) *\n",
        "                          tf.cast(reward[0, 0], dtype=tf.int32))) / 2\n",
        "\n",
        "    new_situation = tf.cond(needs_action,\n",
        "                            new_situation_fn,\n",
        "                            lambda: self._situation)\n",
        "    tf.compat.v1.assign(self._situation, new_situation)\n",
        "    return tf_agent.LossInfo((), ())\n",
        "\n",
        "sign_agent = SignAgent()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oyclF0ZZpW-f"
      },
      "source": [
        "In the above code, the agent defines the policy, and the variable `situation` is shared by the agent and the policy.\n",
        "\n",
        "Also, the parameter `experience` of the `_train` function is a trajectory:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3NlF228LGoiR"
      },
      "source": [
        "# Trajectories"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2GbBDi1iGsnN"
      },
      "source": [
        "In TF-Agents, `trajectories` are named tuples that contain samples from previous steps taken. These samples are then used by the agent to train and update the policy. In RL, trajectories must contain information about the current state, the next state, and whether the current episode has ended. Since in the Bandit world we do not need these things, we set up a helper function to create a trajectory:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gdSG1nv-HUJq"
      },
      "outputs": [],
      "source": [
        "# We need to add another dimension here because the agent expects the\n",
        "# trajectory of shape [batch_size, time, ...], but in this tutorial we assume\n",
        "# that both batch size and time are 1. Hence all the expand_dims.\n",
        "\n",
        "def trajectory_for_bandit(initial_step, action_step, final_step):\n",
        "  return trajectory.Trajectory(observation=tf.expand_dims(initial_step.observation, 0),\n",
        "                               action=tf.expand_dims(action_step.action, 0),\n",
        "                               policy_info=action_step.info,\n",
        "                               reward=tf.expand_dims(final_step.reward, 0),\n",
        "                               discount=tf.expand_dims(final_step.discount, 0),\n",
        "                               step_type=tf.expand_dims(initial_step.step_type, 0),\n",
        "                               next_step_type=tf.expand_dims(final_step.step_type, 0))\n"
      ]
    },
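    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The effect of the `expand_dims` calls on shapes can be checked with a small NumPy sketch. The input shapes below are assumptions that correspond to the setup of this tutorial (batch size 1, a one-element observation, and a policy that emits actions of shape `(1,)`); `np.expand_dims` stands in for `tf.expand_dims`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Assumed shapes for a single step with batch size 1.\n",
        "observation = np.array([[2]], dtype=np.int32)  # Shape (1, 1).\n",
        "action = np.array([1], dtype=np.int32)         # Shape (1,).\n",
        "reward = np.array([2.0], dtype=np.float32)     # Shape (1,).\n",
        "\n",
        "# Add one leading axis, as the helper function does with tf.expand_dims(x, 0).\n",
        "traj_observation = np.expand_dims(observation, 0)\n",
        "traj_action = np.expand_dims(action, 0)\n",
        "traj_reward = np.expand_dims(reward, 0)\n",
        "\n",
        "print(traj_observation.shape)  # (1, 1, 1)\n",
        "print(traj_action.shape)       # (1, 1)\n",
        "print(traj_reward.shape)       # (1, 1)\n",
        "```\n",
        "\n",
        "These shapes are why the agent's `_train` function reads the sample with `observation[0, 0, 0]`, `action[0, 0]`, and `reward[0, 0]`."
      ]
    },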
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zFEJ8kbI_e6Q"
      },
      "source": [
        "# Training an Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0Gh-41og_hDB"
      },
      "source": [
        "Now all the pieces are ready for training our bandit agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "LPx43dZgoyKg"
      },
      "outputs": [],
      "source": [
        "step = two_way_tf_environment.reset()\n",
        "for _ in range(10):\n",
        "  action_step = sign_agent.collect_policy.action(step)\n",
        "  next_step = two_way_tf_environment.step(action_step.action)\n",
        "  experience = trajectory_for_bandit(step, action_step, next_step)\n",
        "  print(experience)\n",
        "  sign_agent.train(experience)\n",
        "  step = next_step\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4iVSNiYdy4U4"
      },
      "source": [
        "From the output one can see that after the second step (unless the observation was 0 in the first step), the policy chooses the action in the right way and thus the reward collected is always non-negative."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RCKyKEjOlOPE"
      },
      "source": [
        "# A Real Contextual Bandit Example"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ecnQwUpmllar"
      },
      "source": [
        "In the rest of this tutorial, we use the pre-implemented [environments](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/) and [agents](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/agents/) of the TF-Agents Bandits library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oEnXUwd-nZKl"
      },
      "outputs": [],
      "source": [
        "# Imports for example.\n",
        "from tf_agents.bandits.agents import lin_ucb_agent\n",
        "from tf_agents.bandits.environments import stationary_stochastic_py_environment as sspe\n",
        "from tf_agents.bandits.metrics import tf_metrics\n",
        "from tf_agents.drivers import dynamic_step_driver\n",
        "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n",
        "\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "37oy70dUmmie"
      },
      "source": [
        "## Stationary Stochastic Environment with Linear Payoff Functions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "euPPd8x1m7iG"
      },
      "source": [
        "The environment used in this example is the [StationaryStochasticPyEnvironment](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/stationary_stochastic_py_environment.py). This environment takes as parameter a (usually noisy) function for giving observations (context), and for every arm takes an (also noisy) function that computes the reward based on the given observation. In our example, we sample the context uniformly from a d-dimensional cube, and the reward functions are linear functions of the context, plus some Gaussian noise."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gVa0hmQrpe6w"
      },
      "outputs": [],
      "source": [
        "batch_size = 2 # @param\n",
        "arm0_param = [-3, 0, 1, -2] # @param\n",
        "arm1_param = [1, -2, 3, 0] # @param\n",
        "arm2_param = [0, 0, 1, 1] # @param\n",
        "def context_sampling_fn(batch_size):\n",
        "  \"\"\"Contexts from [-10, 10]^4.\"\"\"\n",
        "  def _context_sampling_fn():\n",
        "    return np.random.randint(-10, 10, [batch_size, 4]).astype(np.float32)\n",
        "  return _context_sampling_fn\n",
        "\n",
        "class LinearNormalReward(object):\n",
        "  \"\"\"A class that acts as linear reward function when called.\"\"\"\n",
        "  def __init__(self, theta, sigma):\n",
        "    self.theta = theta\n",
        "    self.sigma = sigma\n",
        "  def __call__(self, x):\n",
        "    mu = np.dot(x, self.theta)\n",
        "    return np.random.normal(mu, self.sigma)\n",
        "\n",
        "arm0_reward_fn = LinearNormalReward(arm0_param, 1)\n",
        "arm1_reward_fn = LinearNormalReward(arm1_param, 1)\n",
        "arm2_reward_fn = LinearNormalReward(arm2_param, 1)\n",
        "\n",
        "environment = tf_py_environment.TFPyEnvironment(\n",
        "    sspe.StationaryStochasticPyEnvironment(\n",
        "        context_sampling_fn(batch_size),\n",
        "        [arm0_reward_fn, arm1_reward_fn, arm2_reward_fn],\n",
        "        batch_size=batch_size))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "haID-SPgsLyY"
      },
      "source": [
        "## The LinUCB Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "298-1Q0bsQmR"
      },
      "source": [
        "The agent below implements the [LinUCB](http://rob.schapire.net/papers/www10.pdf) algorithm."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "p4XmGgIusj-K"
      },
      "outputs": [],
      "source": [
        "observation_spec = tensor_spec.TensorSpec([4], tf.float32)\n",
        "time_step_spec = ts.time_step_spec(observation_spec)\n",
        "action_spec = tensor_spec.BoundedTensorSpec(\n",
        "    dtype=tf.int32, shape=(), minimum=0, maximum=2)\n",
        "\n",
        "agent = lin_ucb_agent.LinearUCBAgent(time_step_spec=time_step_spec,\n",
        "                                     action_spec=action_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Eua_aC7Rt78G"
      },
      "source": [
        "## Regret Metric"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FBJDiJvEt-xC"
      },
      "source": [
        "Bandits' most important metric is *regret*, calculated as the difference between the reward collected by the agent and the expected reward of an oracle policy that has access to the reward functions of the environment. The [RegretMetric](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/metrics/tf_metrics.py) thus needs a *baseline_reward_fn* function that calculates the best achievable expected reward given an observation. For our example, we need to take the maximum of the no-noise equivalents of the reward functions that we already defined for the environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cX7MiFhNu3_L"
      },
      "outputs": [],
      "source": [
        "def compute_optimal_reward(observation):\n",
        "  expected_reward_for_arms = [\n",
        "      tf.linalg.matvec(observation, tf.cast(arm0_param, dtype=tf.float32)),\n",
        "      tf.linalg.matvec(observation, tf.cast(arm1_param, dtype=tf.float32)),\n",
        "      tf.linalg.matvec(observation, tf.cast(arm2_param, dtype=tf.float32))]\n",
        "  optimal_action_reward = tf.reduce_max(expected_reward_for_arms, axis=0)\n",
        "  return optimal_action_reward\n",
        "\n",
        "regret_metric = tf_metrics.RegretMetric(compute_optimal_reward)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YRWz-Qeb13JC"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "khdKjTs516Pg"
      },
      "source": [
        "Now we put together all the components that we introduced above: the environment, the policy, and the agent. We run the policy on the environment and output training data with the help of a *driver*, and train the agent on the data.\n",
        "\n",
        "Note that there are two parameters that together specify the number of steps taken. `num_iterations` specifies how many times we run the trainer loop, while the driver will take `steps_per_loop` steps per iteration. The main reason behind keeping both of these parameters is that some operations are done per iteration, while some are done by the driver in every step. For example, the agent's `train` function is only called once per iteration. The trade-off here is that if we train more often then our policy is \"fresher\", on the other hand, training in bigger batches might be more time efficient."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4Ggn45g62DWx"
      },
      "outputs": [],
      "source": [
        "num_iterations = 90 # @param\n",
        "steps_per_loop = 1 # @param\n",
        "\n",
        "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n",
        "    data_spec=agent.policy.trajectory_spec,\n",
        "    batch_size=batch_size,\n",
        "    max_length=steps_per_loop)\n",
        "\n",
        "observers = [replay_buffer.add_batch, regret_metric]\n",
        "\n",
        "driver = dynamic_step_driver.DynamicStepDriver(\n",
        "    env=environment,\n",
        "    policy=agent.collect_policy,\n",
        "    num_steps=steps_per_loop * batch_size,\n",
        "    observers=observers)\n",
        "\n",
        "regret_values = []\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "  driver.run()\n",
        "  loss_info = agent.train(replay_buffer.gather_all())\n",
        "  replay_buffer.clear()\n",
        "  regret_values.append(regret_metric.result())\n",
        "\n",
        "plt.plot(regret_values)\n",
        "plt.ylabel('Average Regret')\n",
        "plt.xlabel('Number of Iterations')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "J2diHS5IzLuo"
      },
      "source": [
        "After running the last code snippet, the resulting plot (hopefully) shows that the average regret is going down as the agent is trained and the policy gets better in figuring out what the right action is, given the observation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2qLMnOL00-2V"
      },
      "source": [
        "# What's Next?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FOiRWZbf1Drs"
      },
      "source": [
        "To see more working examples, please see the [bandits/agents/examples](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/agents/examples) directory that has ready-to-run examples for different agents and environments."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "bandits_tutorial.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 2",
      "name": "python2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
